import wmi
import time
import datetime
import logging
from collections import OrderedDict, defaultdict
from operator import attrgetter, itemgetter
import defaults

# Module-wide logger shared by every class below (requested by the name 'root').
logger = logging.getLogger('root')

class Computer(object):
    """ Collect WMI inventory data for one host.

    hostName is a (hostName, shortName) pair. Each get_* method builds one or
    more WMI wrapper objects through class_handler and stores their rendered
    output in self.summary; build_summary condenses the results.
    """

    def __init__(self, hostName, scanDate):
        logger.info('New host: {}'.format(hostName[1]))
        self.classes = []  # classNames whose wrapper was built successfully
        self.hostName, self.shortName = hostName
        self.scanDate = scanDate
        self.summary = OrderedDict()

    def connect(self, uname=None, pword=None):
        """ Connect to WMI instance on host.

        Returns True on success and False when the connection fails for any
        reason other than access denied. wmi.x_access_denied is logged and
        re-raised so the caller can deal with bad credentials.
        """
        logger.info('Connecting to host.')
        try:
            if uname:
                self.wmiInstance = wmi.WMI(self.hostName, user=uname, password=pword)
            else:
                # NOTE(review): without credentials this connects to the local
                # machine, not self.hostName -- confirm this is intended.
                self.wmiInstance = wmi.WMI()
        except wmi.x_access_denied:
            logger.error('Access is denied to host.')
            raise
        except Exception as e:
            logger.error('Error connecting to host: {}'.format(e))
            return False
        else:
            return True
        finally:
            # Drop the local references to the credentials.
            del uname
            del pword

    def destroy_connection(self):
        """ Drop the reference to the WMI connection so it can be garbage collected.

        Raises AttributeError if no connection was ever stored.
        """
        logger.info('Tearing down connection.')
        del self.wmiInstance

    def class_handler(self, classToHandle, className, *args):
        """ Build a WMI wrapper and tag it with className.

        classToHandle is called as classToHandle(self.wmiInstance, *args).
        On success the wrapper gets a className attribute and className is
        recorded in self.classes. On any error the failure is logged and a
        NoData placeholder is returned instead, so the result always has an
        output() method.
        """
        logger.info(className)
        try:
            cls = classToHandle(self.wmiInstance, *args)
        except AttributeError as e:
            logger.error('There was an error with WMI: {}'.format(e))
            # Previously fell through and returned None, which crashed callers.
            return NoData(className)
        except Exception as e:
            logger.error('Error retrieving information: {}'.format(e))
            return NoData(className)
        else:
            cls.className = className
            self.classes.append(className)
            return cls

    def get_disks(self):
        """ Collect logical disk information into summary['Disk']. """
        self.disks = self.class_handler(_LogicalDisks, 'Disk')
        self.summary['Disk'] = self.disks.output()

    def get_shares(self):
        """ Collect network share information into summary['Shares']. """
        self.shares = self.class_handler(_NetworkShares, 'Shares')
        self.summary['Shares'] = self.shares.output()

    def get_system(self):
        """ Collect BIOS, processor, computer system and video information. """
        self.bios = self.class_handler(_Bios, 'BIOS')
        self.processor = self.class_handler(_Processor, 'Processor')
        self.system = self.class_handler(_ComputerSystem, 'ComputerSystem')
        self.video = self.class_handler(_Video, 'Video')
        # The BIOS serial number is reported as part of the system output;
        # a failed BIOS lookup (NoData) has no serialnumber attribute.
        self.system.serialnumber = getattr(self.bios, 'serialnumber', None)
        self.system.attr_names.append('serialnumber')
        self.summary['BIOS'] = self.bios.output()
        self.summary['Processor'] = self.processor.output()
        self.summary['System'] = self.system.output()
        self.summary['Video'] = self.video.output()

    def get_os(self):
        """ Collect OS, service, running process and startup information. """
        self.os = self.class_handler(_OS, 'OperatingSystem')
        self.services = self.class_handler(_Services, 'Services')
        self.processes = self.class_handler(_RunningProcesses, 'RunningProcesses')
        self.startup = self.class_handler(_Startup, 'Startup')
        self.summary['OS'] = self.os.output()
        self.summary['Processes'] = self.processes.output()
        self.summary['Services'] = self.services.output()
        self.summary['Startup'] = self.startup.output()

    def get_printers(self):
        """ Collect printer information into summary['Printers']. """
        self.printers = self.class_handler(_Printers, 'Printers')
        self.summary['Printers'] = self.printers.output()

    def get_events(self, history):
        """ Query the event log and return the number of events found.

        history is forwarded to _EventLog; 0 is returned if the query failed.
        """
        self.events = self.class_handler(_EventLog, 'EventLog', history)
        return getattr(self.events, 'count', 0)

    def build_summary(self):
        """ Condense the collected data into {hostName: {...}}.

        The inner dict has the keys 'shortname', 'summary', 'disk', 'events'
        and 'printers'; every value except 'shortname' is a list of rows that
        begin with the host name.

        Requires get_system, get_os and get_events to have been called.
        Fields of wrappers that failed to load (NoData) are reported as None
        and a failed event query counts as 0 events.
        """
        system, processor, osInfo = self.system, self.processor, self.os
        summary = [[self.hostName,
                    getattr(system, 'manufacturer', None),
                    getattr(system, 'model', None),
                    getattr(system, 'serialnumber', None),
                    getattr(processor, 'name', None),
                    getattr(processor, 'numberofcores', None),
                    getattr(system, 'totalphysicalmemory', None),
                    getattr(system, 'username', None),
                    getattr(osInfo, 'caption', None),
                    getattr(osInfo, 'osarchitecture', None),
                    getattr(osInfo, 'servicepackmajorversion', None),
                    getattr(osInfo, 'installdate', None),
                    len(self.processes.list)]]
        # diskAlert and alert both start with a header row, hence the [1:].
        disk = ([[self.hostName, d[0], d[1]] for d in self.disks.diskAlert[1:]]
                if 'Disk' in self.classes else [])
        events = [[self.hostName, getattr(self.events, 'count', 0)]]
        printers = ([[self.hostName] + d for d in self.printers.alert[1:]]
                    if 'Printers' in self.classes else [])
        return {self.hostName: {'shortname': self.shortName,
                                'summary': summary,
                                'disk': disk,
                                'events': events,
                                'printers': printers}}

class _DeriveMe(object):
    def set_attrs(self, obj, attr_names):
        map(lambda n: setattr(self, n, getattr(obj, n, None)), attr_names)
    def get_attrs(self, obj, attr_names):
        return [ getattr(obj, attribute, None) for attribute in attr_names ]
    def format_date(self, date):
        return time.strptime(date.split('.')[0], '%Y%m%d%H%M%S')
    def format_date_output(self, date):
        return time.strftime('%Y-%m-%d %H:%M:%S', date)
    def get_data_format(self):
        return getattr(self, 'dataFormat', 'obj')
    def output(self):
        logger.info('Generating output for {}'.format(self.className))
        form = self.get_data_format()
        output = []
        output.append([ getattr(self, 'className', '') ])
        if form == 'obj':
            output.append(self.attr_names)
            output.append(self.get_attrs(self, self.attr_names))
        elif form == 'flat':
            output.append(self.attr_names)
            output.extend(self.list)
        elif form == 'list':
            output.append(self.list[0].attr_names)
            output.extend([i.get_attrs(i, i.attr_names) for i in self.list])
        return output
    def __str__(self):
        names = getattr(self, 'attr_names')
        attrs = [ getattr(self, attr) for attr in names ]
        return '\n'.join([ '{}: {}'.format(x, y) for x, y in zip(names, attrs) ])

class NoData(_DeriveMe):
    """ Placeholder used when a WMI class could not be retrieved.

    It has an empty attr_names and an empty 'flat' list, so output() yields
    just the class name row followed by an empty header row.
    """

    def __init__(self, className):
        self.dataFormat = 'flat'
        self.attr_names = []
        self.list = []
        self.className = className

class BadAttribute(AttributeError):
    """ Custom AttributeError subtype.

    NOTE(review): not raised or caught anywhere else in this module.
    """


class _Bios(_DeriveMe):
    """ Serial number and SMBIOS BIOS version from the first Win32_BIOS entry. """

    def __init__(self, wmiInstance):
        self.attr_names = ['serialnumber', 'smbiosbiosversion']
        firstBios = wmiInstance.Win32_BIOS()[0]
        self.set_attrs(firstBios, self.attr_names)


class _Processor(_DeriveMe):
    """ Name and core counts from the first Win32_Processor entry. """

    def __init__(self, wmiInstance):
        self.attr_names = ['name', 'numberofcores', 'numberoflogicalprocessors']
        firstCpu = wmiInstance.Win32_Processor()[0]
        self.set_attrs(firstCpu, self.attr_names)


class _ComputerSystem(_DeriveMe):
    """ Selected fields of the first Win32_ComputerSystem entry. """

    def __init__(self, wmiInstance):
        self.attr_names = [
            'manufacturer', 'model', 'partofdomain',
            'numberofprocessors', 'totalphysicalmemory',
            'username',
        ]
        firstSystem = wmiInstance.Win32_ComputerSystem()[0]
        self.set_attrs(firstSystem, self.attr_names)


class _NetworkShares(_DeriveMe):
    """ Shares reported by Win32_Share, one Share object per entry in self.list. """

    def __init__(self, wmiInstance):
        self.wmiShares = wmiInstance.Win32_Share()
        self.dataFormat = 'list'
        self.list = []
        self.get_all_shares()

    class Share(_DeriveMe):
        """ Name, path and caption of a single share. """

        def __init__(self, share):
            self.className = 'Share'
            self.attr_names = ['name', 'Path', 'caption']
            self.set_attrs(share, self.attr_names)

    def get_all_shares(self):
        """ Wrap every WMI share in a Share and append it to self.list. """
        append = self.list.append
        for wmiShare in self.wmiShares:
            logger.info('Found share: {}'.format(wmiShare.name))
            append(self.Share(wmiShare))


class _Video(_DeriveMe):
    """ Controllers from Win32_VideoController, one Adapter per entry in self.list. """

    def __init__(self, wmiInstance):
        self.wmiVideo = wmiInstance.Win32_VideoController()
        self.dataFormat = 'list'
        self.list = []
        self.get_all_video()

    class Adapter(_DeriveMe):
        """ adaptercompatibility, caption and driverversion of one controller. """

        def __init__(self, adapter):
            self.className = 'Adapter'
            self.attr_names = ['adaptercompatibility', 'caption', 'driverversion']
            self.set_attrs(adapter, self.attr_names)

    def get_all_video(self):
        """ Append an Adapter to self.list for every controller WMI returned. """
        for controller in self.wmiVideo:
            logger.info('Found video: {}'.format(controller.caption))
            self.list.append(self.Adapter(controller))


class _OS(_DeriveMe):
    """ Fields of the first Win32_OperatingSystem entry.

    installdate is converted to 'YYYY-mm-dd HH:MM:SS'.
    """

    def __init__(self, wmiInstance):
        self.attr_names = [
            'caption', 'osarchitecture',
            'servicepackmajorversion', 'sizestoredinpagingfiles',
            'installdate',
        ]
        osInfo = wmiInstance.Win32_OperatingSystem()[0]
        self.set_attrs(osInfo, self.attr_names)
        parsedInstall = self.format_date(self.installdate)
        self.installdate = self.format_date_output(parsedInstall)


class _Printers(_DeriveMe):
    """ Printers from Win32_Printer.

    Every printer becomes a Printer in self.list. Printers whose translated
    status is not Idle, Printing or Warming Up are also added to self.alert,
    whose first row is a header.
    """

    def __init__(self, wmiInstance):
        self.list = []
        self.alert = [['Printer', 'Status', 'Error State', 'Extended State']]
        self.dataFormat = 'list'
        self.wmiPrinters = wmiInstance.Win32_Printer()
        self.get_all_printers()

    class Printer(_DeriveMe):
        """ A single printer with its status codes translated via defaults. """

        def __init__(self, printer):
            self.className = 'Printer'
            self.attr_names = ['name', 'sharename', 'portname', 'printerstatus',
                               'detectederrorstate', 'extendeddetectederrorstate']
            self.set_attrs(printer, self.attr_names)
            self.get_printer_status()

        def get_printer_status(self):
            """ Replace raw status codes with their names; unknown codes are kept. """
            translations = (
                ('printerstatus', defaults.PRINTER_STATUS),
                ('detectederrorstate', defaults.PRINTER_ERROR_STATUS),
                ('extendeddetectederrorstate', defaults.PRINTER_EXTENDED_STATUS),
            )
            for attribute, table in translations:
                code = getattr(self, attribute)
                setattr(self, attribute, table.get(code, code))

    def get_all_printers(self):
        """ Wrap each WMI printer and flag those not in a normal operating state. """
        healthy = ('Idle', 'Printing', 'Warming Up')
        for wmiPrinter in self.wmiPrinters:
            logger.info('Found printer: {}'.format(wmiPrinter.name))
            printer = self.Printer(wmiPrinter)
            if printer.printerstatus not in healthy:
                self.alert.append([printer.name,
                                   printer.printerstatus,
                                   printer.detectederrorstate,
                                   printer.extendeddetectederrorstate])
            self.list.append(printer)


class _Services(_DeriveMe):
    """ Services from Win32_Service, stored as flat rows in self.list. """

    def __init__(self, wmiInstance):
        self.wmiServices = wmiInstance.Win32_Service()
        self.attr_names = ['caption', 'name', 'state', 'startmode', 'startname', 'pathname',
                           'acceptpause', 'acceptstop']
        self.dataFormat = 'flat'
        self.list = []
        self.get_all_services()
        logger.info('Found services: [{}]'.format(len(self.list)))

    def get_all_services(self):
        """ Rebuild self.list with one row of attr_names values per service. """
        columns = self.attr_names
        rows = [self.get_attrs(svc, columns) for svc in self.wmiServices]
        self.list = rows

    def __str__(self):
        serviceCount = len(self.list)
        return 'Total services: {}'.format(serviceCount)


class _RunningProcesses(_DeriveMe):
    """ Processes from Win32_Process, stored as flat rows in self.list. """

    def __init__(self, wmiInstance):
        self.attr_names = ['name', 'processid', 'executablepath', 'workingsetsize']
        # The column list is passed to the WMI call (presumably restricting
        # the query to these fields).
        self.wmiProcesses = wmiInstance.Win32_Process(self.attr_names)
        self.dataFormat = 'flat'
        self.list = []
        self.get_all_processes()
        logger.info('Found processes: [{}]'.format(len(self.list)))

    def get_all_processes(self):
        """ Rebuild self.list with one row of attr_names values per process. """
        columns = self.attr_names
        self.list = [self.get_attrs(proc, columns) for proc in self.wmiProcesses]

    def __str__(self):
        processCount = len(self.list)
        return 'Total running processes: {}'.format(processCount)


class _Startup(_DeriveMe):
    """ Startup commands from Win32_StartupCommand, stored as flat rows in self.list. """

    def __init__(self, wmiInstance):
        self.wmiStartup = wmiInstance.Win32_StartupCommand()
        self.attr_names = ['location', 'caption', 'command']
        self.dataFormat = 'flat'
        self.list = []
        self.get_all_startup()
        logger.info('Found startup applications: [{}]'.format(len(self.list)))

    def get_all_startup(self):
        """ Rebuild self.list with one [location, caption, command] row per entry. """
        columns = self.attr_names
        self.list = [self.get_attrs(entry, columns) for entry in self.wmiStartup]

    def __str__(self):
        total = len(self.list)
        return 'Startup applications: {}'.format(total)


class _LogicalDisks(_DeriveMe):
    """ Logical disks from Win32_LogicalDisk, split into local and mapped drives.

    Local drives with less than 25% free space (excluding those whose
    description contains 'ROM') are recorded in self.diskAlert, whose first
    row is a header.
    """

    def __init__(self, wmiInstance):
        self.wmiDisks = wmiInstance.Win32_LogicalDisk()
        self.diskAlert = [['DeviceID', '% Free']]
        self.local = []
        self.network = []
        self.get_all_disks()

    # BEGIN inner classes for individual drives
    class Drive(_DeriveMe):
        """ A local logical disk with sizes in GB and a pctFree attribute. """

        def __init__(self, localDisk):
            self.className = 'Local'
            self.attr_names = ['deviceid', 'description', 'size', 'freespace', 'volumedirty']
            self.set_attrs(localDisk, self.attr_names)
            self.set_space()

        def set_space(self):
            """ Convert size/freespace to GB (10**9 bytes, 2 decimals) and set pctFree.

            pctFree is computed from the raw byte counts: dividing by the
            rounded GB value raised ZeroDivisionError for volumes that round
            to 0.0 GB, and for a size of '0'. A missing or zero size gives
            size = freespace = 0 and pctFree = None.
            """
            sizeBytes = int(self.size) if self.size else 0
            freeBytes = int(self.freespace) if self.freespace else 0
            if sizeBytes > 0:
                self.size = round(sizeBytes / 10.0 ** 9, 2)
                self.freespace = round(freeBytes / 10.0 ** 9, 2)
                self.pctFree = round(100.0 * freeBytes / sizeBytes, 2)
            else:
                self.size = 0
                self.freespace = 0
                self.pctFree = None

    class MappedDrive(Drive):
        """ A mapped (network) drive; also records the provider name. """

        def __init__(self, mappedDrive):
            self.className = 'Remote'
            self.attr_names = ['deviceid', 'providername', 'description', 'size', 'freespace']
            self.set_attrs(mappedDrive, self.attr_names)
            self.set_space()
    # END inner class

    def get_all_disks(self):
        """ Sort WMI disks into self.network / self.local and record low-space alerts. """
        for disk in self.wmiDisks:
            logger.info('Found: {} {}'.format(disk.deviceid, disk.description))
            if disk.DriveType == 4:  # DriveType 4 is handled as a mapped drive
                newDrive = self.MappedDrive(disk)
                self.network.append(newDrive)
            else:
                newDrive = self.Drive(disk)
                # description may be None; pctFree is None when size is unknown.
                description = newDrive.description or ''
                if ('ROM' not in description and newDrive.pctFree is not None
                        and newDrive.pctFree < 25):
                    self.diskAlert.append([newDrive.deviceid, newDrive.pctFree])
                    logger.warning('Disk alert: {} -- {}'.format(newDrive.deviceid,
                                                               newDrive.pctFree))
                self.local.append(newDrive)

    def output(self):
        """ Render local disks, mapped disks (if any) and alerts (if any) as rows.

        Sections are separated by defaults.OUTPUT_SPACES.
        """
        logger.info('Generating output for Disks.')
        output = [[getattr(self, 'className', '')]]
        output.extend(defaults.OUTPUT_SPACES)
        output.append(['Local Disks'])
        if self.local:  # was IndexError when there were no local disks
            output.append(self.local[0].attr_names)
            output.extend([i.get_attrs(i, i.attr_names) for i in self.local])
        if self.network:
            output.extend(defaults.OUTPUT_SPACES)
            output.append(['Network Disks'])
            output.append(self.network[0].attr_names)
            output.extend([i.get_attrs(i, i.attr_names) for i in self.network])
        # Row 0 is the header, so the old truthiness check was always true.
        if len(self.diskAlert) > 1:
            output.extend(defaults.OUTPUT_SPACES)
            output.append(['Disk Alert'])
            output.extend(self.diskAlert)
        return output


class _EventLog(_DeriveMe):
    """ Error/critical Application and System events from Win32_NTLogEvent.

    eventHistory is the number of days to look back (a falsy value means no
    date limit) or -1 to skip the query entirely. Events are kept in full in
    self.complete (newest first) and grouped by logfile/source/event code in
    self.summary (most frequent first).
    """

    def __init__(self, wmiInstance, eventHistory):
        self.sumAttrs = ['logfile', 'sourcename', 'eventcode',
                         'type', 'message', 'first', 'last',
                         'count']
        self.attr_names = ['logfile', 'recordnumber', 'sourcename',
                           'eventcode', 'type', 'message',
                           'timegenerated']
        self.sumCache = {}  # keyString -> SumEvent
        self.summary = []
        self.complete = []
        self.count = 0
        # NOTE: not read by this class; output() is overridden and the base
        # class's get_data_format() looks for 'dataFormat' instead.
        self.data_format = ('list', ('summary', 'complete'))
        if eventHistory == -1:
            logger.debug('eventHistory is -1, skipping events.')
        else:
            self.query = self.build_event_query(self.attr_names, eventHistory)
            logger.debug('Event query: {}'.format(self.query))
            self.wmiEvents = wmiInstance.query(self.query)
            self.count = len(self.wmiEvents)
            logger.info('Found events: [{}]'.format(self.count))
            if self.count:
                logger.info('Processing events')
                self.process_events()
                logger.info('Finalizing events')
                self.finalize_events()
                logger.info('Done')

    def build_event_query(self, attributes, history):
        """ Build the UTF-8 encoded WQL query for error/critical events.

        Selects the given attributes from the Application and System logs,
        optionally limited to the last `history` days.
        """
        types = ' or '.join(['Type="{}"'.format(e)
                            for e in ['error', 'critical']])
        type_query = '({})'.format(types)
        logfile_query = '(Logfile="Application" or Logfile="System")'
        history_query = self.get_event_date(history)
        attributes = ','.join(attributes)
        return 'SELECT {} from Win32_NTLogEvent WHERE {} AND {}{}'.format(attributes,
                  type_query, logfile_query, history_query).encode('utf-8')

    def get_event_date(self, daysBack):
        """ Return a TimeGenerated filter clause starting at midnight (local
        time) daysBack days ago, or '' when daysBack is falsy. """
        if not daysBack:
            return ''
        goBackToDate = datetime.datetime.now() - datetime.timedelta(days=daysBack)
        formatted_date = goBackToDate.strftime('%m/%d/%Y 00:00:00')
        return ' AND TimeGenerated>="{}"'.format(formatted_date)

    # Begin inner classes
    class Event(_DeriveMe):
        """ One event record with a readable timestamp and a grouping key. """

        def __init__(self, event, attrs):
            self.attr_names = attrs
            self.set_attrs(event, attrs)
            self.className = 'Event'
            self.timegenerated = self.format_date_output(self.format_date(self.timegenerated))
            # Only byte strings need decoding; text or a missing message is
            # kept as-is instead of logging a spurious decode warning.
            if isinstance(self.message, bytes):
                try:
                    self.message = self.message.decode('utf-8')
                except UnicodeDecodeError as e:
                    logger.warning('Could not decode message: {}'.format(e))
            self.keyString = '{} :: {} :: {}'.format(self.logfile, self.sourcename,
                                                     self.eventcode)

    class SumEvent(_DeriveMe):
        """ Aggregate of all events sharing a keyString, seeded from the first one. """

        def __init__(self, event, attrs):
            self.className = 'Event Summary'
            self.attr_names = attrs
            self.set_attrs(event, attrs)
            self.first = event.timegenerated
            self.last = event.timegenerated
            self.count = 1
    # End inner classes

    def process_events(self):
        """ Wrap every WMI event, keep it in self.complete and fold it into sumCache. """
        for event in self.wmiEvents:
            newEvent = self.Event(event, self.attr_names)
            self.complete.append(newEvent)
            self.update_cache(self.sumCache, newEvent)
            logger.debug('Processed event: {} {} {}'.format(newEvent.eventcode,
                            newEvent.sourcename, newEvent.recordnumber))

    def update_cache(self, cache, event):
        """ Fold event into cache (keyed by event.keyString).

        Known keys get their first/last window widened, the message of the
        most recent event, and count + 1; unknown keys get a new SumEvent.
        """
        # timegenerated is 'YYYY-mm-dd HH:MM:SS', so string order is time order.
        if event.keyString in cache:
            logger.debug('Event in cache')
            if event.timegenerated < cache[event.keyString].first:
                cache[event.keyString].first = event.timegenerated
            elif event.timegenerated > cache[event.keyString].last:
                cache[event.keyString].last = event.timegenerated
                cache[event.keyString].message = event.message
            cache[event.keyString].count += 1
        else:
            logger.debug('Adding event to cache')
            sumEvent = self.SumEvent(event, self.sumAttrs)
            cache[event.keyString] = sumEvent

    def finalize_events(self):
        """ Sort summaries by count and full events by time, both descending. """
        self.summary = sorted(self.sumCache.values(), key=attrgetter('count'), reverse=True)
        self.complete.sort(key=attrgetter('timegenerated'), reverse=True)

    def output(self):
        """ Return (summary rows, complete rows), each led by its header row.

        Headers come from sumAttrs / attr_names (the same lists the wrapped
        events use), so output works even when no events were collected.
        """
        logger.info('Generating output for EventLog.')
        outSummary = [self.sumAttrs]
        outComplete = [self.attr_names]
        outSummary.extend([i.get_attrs(i, i.attr_names) for i in self.summary])
        outComplete.extend([i.get_attrs(i, i.attr_names) for i in self.complete])
        return outSummary, outComplete